from fastapi import FastAPI, HTTPException
from fastapi.responses import StreamingResponse, FileResponse
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel
from recognition_server import RecognitionServer
from task_executor import task_executor
import logging
from typing import Dict, Any, Optional
import json
import os

# Configure root logging; force=True discards any handlers installed earlier
# (e.g. by imported libraries) so this format and level win.
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
    force=True  # force this configuration to take effect
)
logger = logging.getLogger(__name__)

# Strip any handlers the project modules attached to their own loggers so
# their records propagate up to the root handler configured above
# (avoids duplicate or differently-formatted log lines).
for name in ['mermaid_agent', 'recognition_server', 'task_executor']:
    module_logger = logging.getLogger(name)
    for handler in module_logger.handlers[:]:
        module_logger.removeHandler(handler)

app = FastAPI()

# Allow cross-origin requests from any origin.
# NOTE(review): allow_origins=["*"] combined with allow_credentials=True is
# overly permissive — browsers reject credentialed wildcard CORS, and this
# should be narrowed to known origins for production. Confirm dev-only intent.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

# Shared intent-recognition backend used by the endpoints below.
recognition_server = RecognitionServer()

class PipelineRequest(BaseModel):
    """Request payload for the /process endpoint."""
    query: str                  # the user's natural-language query (required)
    problem_content: str = ""   # optional problem statement giving context
    editor_code: str = ""       # optional code currently in the user's editor

class PipelineService:
    """Orchestrates one request: pushes the caller's context into the task
    executor, then runs intent recognition on the query."""

    def __init__(self):
        pass

    async def process_request(self, request: PipelineRequest) -> Dict[str, Any]:
        """Run the pipeline for a single request.

        Args:
            request: validated payload carrying the query plus optional
                problem statement and editor code.

        Returns:
            The intent-recognition result dictionary.

        Raises:
            HTTPException: 400 when intent recognition produces no result,
                500 for any unexpected internal error.
        """
        try:
            # 1. Make the caller's context available to downstream task execution.
            task_executor.set_problem_content(request.problem_content)
            task_executor.set_editor_code(request.editor_code)

            # 2. Intent recognition and handling.
            intent_result = recognition_server.process_request(
                query=request.query,
                problem_content=request.problem_content,
                editor_code=request.editor_code
            )

            if not intent_result:
                raise HTTPException(status_code=400, detail="意图识别失败")

            return intent_result

        except HTTPException:
            # Bug fix: the 400 raised above was previously swallowed by the
            # broad handler below and re-raised as a 500, hiding the intended
            # status code from callers. Let HTTPExceptions pass through intact.
            raise
        except Exception as e:
            # logger.exception records the traceback; lazy %-args defer
            # formatting until the record is actually emitted.
            logger.exception("处理请求时出错: %s", e)
            raise HTTPException(status_code=500, detail=str(e))

# Single service instance shared by all request handlers (the service itself
# holds no per-request state; context lives in task_executor).
pipeline_service = PipelineService()

@app.post("/process")
async def process_pipeline(request: PipelineRequest):
    """Handle a pipeline request: log the incoming payload, then delegate
    to the shared PipelineService and return its result as JSON."""
    # Lazy %-style args defer string formatting until the record is emitted
    # (and avoid it entirely if INFO is disabled).
    logger.info("收到请求:")
    logger.info("查询: %s", request.query)
    logger.info("编辑器代码: %s", request.editor_code)
    logger.info("题目内容: %s", request.problem_content)
    return await pipeline_service.process_request(request)

@app.get("/process/stream")
async def process_pipeline_stream(
    query: str,
    problem_content: str = "",
    editor_code: str = ""
):
    """Stream the pipeline result as Server-Sent Events.

    Each frame is `data: {"chunk": ...}`; a final frame carries the sentinel
    '[DONE]' so clients know the stream finished normally. Errors raised
    inside the generator are reported in-band as `data: {"error": ...}`
    because response headers have already been sent at that point.

    Raises:
        HTTPException: 500 if streaming fails to start.
    """
    try:
        logger.info("收到流式请求:")
        # Lazy %-args avoid building the message unless INFO is enabled.
        logger.info("查询: %s", query)
        logger.info("编辑器代码: %s", editor_code)
        logger.info("题目内容: %s", problem_content)

        async def generate():
            try:
                async for chunk in recognition_server.process_request_stream(
                    query=query,
                    problem_content=problem_content,
                    editor_code=editor_code
                ):
                    # Skip empty chunks so clients never see blank frames.
                    if chunk:
                        yield f"data: {json.dumps({'chunk': chunk})}\n\n"
                # End-of-stream sentinel.
                yield f"data: {json.dumps({'chunk': '[DONE]'})}\n\n"
            except Exception as e:
                error_msg = f"处理流式请求时出错: {str(e)}"
                logger.error(error_msg)
                # Headers are already out — report the error as an SSE frame.
                yield f"data: {json.dumps({'error': error_msg})}\n\n"

        return StreamingResponse(
            generate(),
            media_type="text/event-stream"
        )

    except Exception as e:
        error_msg = f"启动流式处理时出错: {str(e)}"
        logger.error(error_msg)
        raise HTTPException(status_code=500, detail=error_msg)

@app.get("/health")
async def health_check():
    """Liveness probe: report that the service is up and responsive."""
    report = {"status": "healthy"}
    return report

@app.get("/test")
async def test_page():
    """Serve the bundled manual test page that lives next to this module."""
    here = os.path.dirname(os.path.abspath(__file__))
    page_path = os.path.join(here, "test.html")
    return FileResponse(page_path)

if __name__ == "__main__":
    # Development entry point: serve the app directly with uvicorn on all
    # interfaces, port 8001.
    import uvicorn
    uvicorn.run(app, host="0.0.0.0", port=8001)
